
Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713

Open
cspades wants to merge 11 commits into NVIDIA:main from cspades:cye/fsdp2-tp-dcp

Conversation


@cspades cspades commented Feb 26, 2026

Summary

  • Support (H/F)SDP2 x TP strided sharding, and DTensor FP8 parameters for Torch DCP checkpointing, across all TransformerEngineBaseModule(s).
    • Except GroupedLinear, pending FSDP2 standalone pipe-cleaning. All other modules under transformer_engine.pytorch.modules are supported.
    • FusibleOperation support is also a WIP, except for LayerNorm and RMSNorm, which are TE modules.
  • Associated with BioNeMo-Recipes Llama3 TP: Enable TransformerEngine-backed Tensor Parallelism with Llama3. bionemo-framework#1483
    • Notably, TransformerEngine TP can be easily mixed with DTensor-based TP when unified by Torch DCP. In the Llama3 recipe, we use DTensor-based TP on the torch.nn.Embedding, TransformerEngine-based TP on the LM head, and weight-tie the LM head to the torch.nn.Embedding, which is why we do not need to call set_device_mesh for the LM head.
  • Credit to @pstjohn for coming up with this idea!

Usage / Documentation

(tp_mesh and weight_mesh can also be passed in TEModule.__init__.)

    def set_device_mesh(
        self,
        tp_mesh: Optional[DeviceMesh] = None,
        weight_mesh: Optional[DeviceMesh] = None,
    ) -> None:
        """
        Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor
        depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2.

        TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless
        integration with Torch DCP checkpointing. This method should only be invoked when
        using DTensor parameters, e.g. when using FSDP2 or DCP.

        When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically
        convert them into FSDP-TP strided or non-strided shards depending on the current
        sharding dimension and factor of the DTensor. When the sharding dimension of FSDP
        matches that of TP, FSDP uses a _StridedShard placement type instead of Shard.
        This experimental FSDP-TP logic resides in this FSDP2 initialization function:
        ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param``

        Parameters
        ----------
        tp_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"].
            Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP.
        weight_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required
            when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor
            parameters and if the DTensor DeviceMesh includes dimensions that do not
            shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard).
            For example:
                - device_mesh["dp"] for FSDP.
                - device_mesh["dp_cp"] if using CP ranks in FSDP.
                - device_mesh["dp_shard"] if using HSDP ("dp_replicate", "dp_shard").
                - device_mesh["tp"] if using TP.
                - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
        """

Details

DTensor Lifecycle in TransformerEngine

  • Initialization
    • __init__
      • TransformerEngine model parameters are initialized either on device or meta device with the appropriate tp_size and TP sharding strategy, e.g. parallel_mode and sequence_parallel.
    • TransformerEngineModule.set_device_mesh(tp_mesh, weight_mesh)
      • Converts parameters to DTensor with appropriate TP placement(s) based on the TP sharding strategy specified in __init__, using transformer_engine.pytorch.distributed._convert_param_to_dtensor_param.
        • tp_mesh is a 1-D DeviceMesh containing the TP ProcessGroup that will be registered with the TransformerEngine module.
        • weight_mesh is the 1-D DeviceMesh containing the ProcessGroup that shards TransformerEngine module weights, the flattened combination of groups such as FSDP and TP. Specifically, it excludes non-weight groups such as DP-Replicate when using HSDP or HSDP-TP and is mainly required for per-Tensor scaling recipes like Float8CurrentScaling.
      • Needs to be invoked prior to fully_shard (which responds to the TP placements) and prior to reset_parameters(defer_init=False), which quantizes parameters.
      • Can also be directly invoked during __init__(tp_mesh, weight_mesh) for supported TransformerEngine modules.
    • fully_shard shards the TransformerEngine model with FSDP2.
      • When fully_shard encounters TP sharding on dim=0, it will use a _StridedShard for DP. Put simply, this "pre-shards" the data prior to sharding on the current placement, followed by concatenating the pre-shards to get strided shards that will be re-sharded by the next placement. This effectively reverses the sharding order when processing the placements from left-to-right, and distributes shards as if we sharded on TP first, then FSDP, as required, even though DP appears before TP in the DeviceMesh and DTensor.placements. (See Appendix for visualization of this sharding strategy.)
    • reset_parameters is called if using meta device initialization.
  • Training
    • Pre-forward, FSDP2 all-gathers the sharded DTensor "main" weight that it registered during fully_shard. (Note that this essentially shares the same properties as the compute weight besides shape, and supporting tools such as FusedAdam must be used to properly handle high-precision main weights.)
      • When using FSDP2 x TP, the all-gathered Tensor is actually a TP-sharded DTensor, which deviates from the original FSDP2 paradigm where the all-gathered Tensor is fully unsharded and the DTensor wrapping is discarded. To support these DTensor compute weights in TransformerEngine modules, we use transformer_engine.pytorch.distributed._extract_trainable_tensor_from_dtensor to localize the DTensor and to inherit the requires_grad attribute from the DTensor parameter, since for FP8 parameters specifically the local Tensor has it unset after DTensor.from_local(Tensor).
    • Post-backward, the Tensor gradient is converted to DTensor and attached to the DTensor.grad attribute. Handled by DTensor <> Tensor Autograd conversion functions, and in the case of FusibleOperation, casted during the backward implementation.
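A toy illustration of the identity-preservation requirement above: the extracted local tensor must alias `DTensor._local_tensor`, or FSDP2's in-place all-gather updates would not be visible. `FakeDTensor` is a stand-in for a real DTensor; real code goes through `_extract_trainable_tensor_from_dtensor`.

```python
import torch

class FakeDTensor:
    """Stand-in for torch.distributed.tensor.DTensor, illustration only."""
    def __init__(self, local: torch.Tensor):
        self._local_tensor = local

dt = FakeDTensor(torch.zeros(4))
alias = dt._local_tensor           # identity-preserving extraction
copy = dt._local_tensor.clone()    # a defensive copy would go stale

dt._local_tensor.add_(1.0)         # FSDP2-style in-place all-gather update

print(alias.sum().item())          # 4.0: the alias observes the update
print(copy.sum().item())           # 0.0: the copy does not
```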

QuantizedTensor Storage

  • When both row and column data are None, untyped_storage() returns a default 1-byte storage, which unblocks DCP checkpoint-loading assertions that use storage size as the definition of "emptiness". A 0-byte storage has data_ptr() == nullptr, which breaks DCP.
    • While untyped_storage is not used anywhere in TransformerEngine, this may break external code that inspects Storage to decide whether a Tensor is empty: QuantizedTensor storage is now always a 1-byte storage, even when neither row nor column data is set. Such checks need to compare the storage size against 1 byte instead of 0 bytes.

Bugs

  • Fix bug where "shard" was the presumed weight sharding sub-mesh in the DTensor.device_mesh. Now, users can precisely specify their own custom weight-sharding DeviceMesh for per-tensor amax_reduction_group via the set_device_mesh(weight_mesh) API.
  • Fix TransformerEngineBaseModule quantizer bookkeeping: self.quantizers = {"scaling_fwd": [], "scaling_bwd": []} now stores lists instead of dicts, matching the integer fp8_meta_index lookup semantics.
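The quantizer structure replaces per-key dicts with lists so that the integer fp8_meta_index can index directly. A sketch of the lookup, with made-up quantizer names for illustration:

```python
# Lists let an integer fp8_meta_index address quantizers positionally,
# which a dict keyed by arbitrary objects could not do reliably.
quantizers = {"scaling_fwd": [], "scaling_bwd": []}
quantizers["scaling_fwd"].extend(["q_weight", "q_input", "q_output"])  # illustrative names

fp8_meta_index = 1
print(quantizers["scaling_fwd"][fp8_meta_index])  # q_input
```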

Testing

# TransformerEngine Main
[Rank 0] (after 1 iterations) memory (MB) | allocated: 23511.65 | max allocated: 25189.68 | reserved: 25678.00 | max reserved: 25678.00
 [2026-03-02 09:55:17.189564] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12715.7 | throughput per GPU (TFLOP/s/GPU): 530.6 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124915E+00 | loss scale: 1.0 | grad norm: 5.474 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:55:29.768521] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12578.7 | throughput per GPU (TFLOP/s/GPU): 536.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.143806E+00 | loss scale: 1.0 | grad norm: 5.366 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Post-DCP Modifications (This PR)
[Rank 0] (after 2 iterations) memory (MB) | allocated: 23511.65 | max allocated: 29783.24 | reserved: 25678.00 | max reserved: 31510.00
 [2026-03-02 09:29:36.550070] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12556.5 | throughput per GPU (TFLOP/s/GPU): 537.3 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124463E+00 | loss scale: 1.0 | grad norm: 5.471 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:29:49.216068] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12665.7 | throughput per GPU (TFLOP/s/GPU): 532.7 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.142863E+00 | loss scale: 1.0 | grad norm: 5.355 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • NOTE(@cspades): DelayedScaling has DCP save/load disparity issues, i.e. discrepancies on the order of ±1 in the uint8 parameter checkpoint.

Appendix

_StridedShard - Using FSDP2 x TP Strided-Sharding

# (DP=4, TP=2)
(_StridedShard(dim=0, sf=2), Shard(dim=0))

┌───┬───┐
│ 0 │ 4 │ ← DP=0
├───┼───┤
│ 1 │ 5 │ ← DP=1
├───┼───┤          FSDP all-gather happens across the DP ranks,
│ 2 │ 6 │ ← DP=2   so we need to form the 0-3 and 4-7 TP shards!
├───┼───┤
│ 3 │ 7 │ ← DP=3
└───┴───┘
  ↑   ↑
TP=0 TP=1

When redistributing a global DTensor to (_StridedShard(dim=0, sf=2), Shard(dim=0)), DTensor performs the following steps:

  • Pre-shard the Tensor data with respect to the stride / shard factor, which is defined as the product of the parallelism sizes of all Shard placements to the right of _StridedShard. (In the above example, since TP=2, the factor is 2.)
    • [0 1 2 3 4 5 6 7] -> [0 1 2 3] and [4 5 6 7].
    • In the context of this PR and fully_shard, this has already been done via initializing the TransformerEngine module with TP and calling _convert_param_to_dtensor_param!
  • Shard the pre-shards for _StridedShard.
    • [0] [1] [2] [3] and [4] [5] [6] [7]
  • Concatenate the strided shards.
    • [0 4] [1 5] [2 6] [3 7], which are assigned to the _StridedShard ranks.
    • Note that this is very different if we did left-to-right-sharding, which would have given us [0 1] [2 3] [4 5] [6 7]!
  • Subsequently / finally, each strided shard is sharded on the Shard placement.
    • [0] [4] / [1] [5] / [2] [6] / [3] [7], which are assigned to the Shard ranks.
    • Note that this is very different if we did left-to-right sharding, which would have given us [0] [1] / [2] [3] / [4] [5] / [6] [7]!

PyTorch also supports the inverse (un-sharding) of this redistribute, which is simply the inverse of these operations, though things get a bit more complicated with uneven shards from odd-numbered dimension sizes.
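The four steps above can be simulated in plain Python. `strided_shard` is an illustrative toy that mirrors the described redistribution, not DTensor's actual implementation:

```python
def strided_shard(data, dp, tp):
    """Simulate (_StridedShard(dim=0, sf=tp), Shard(dim=0)) over a 1-D list."""
    n = len(data)
    # Step 1: pre-shard by the shard factor (product of Shard sizes to the
    # right of _StridedShard; here just tp).
    pre = [data[i * n // tp:(i + 1) * n // tp] for i in range(tp)]
    # Step 2: shard each pre-shard across the _StridedShard (DP) ranks.
    sub = [[p[r * len(p) // dp:(r + 1) * len(p) // dp] for r in range(dp)]
           for p in pre]
    out = {}
    for r in range(dp):
        # Step 3: concatenate across pre-shards to form rank r's strided shard.
        strided = [x for t in range(tp) for x in sub[t][r]]
        # Step 4: shard the strided shard across the Shard (TP) ranks.
        m = len(strided)
        for t in range(tp):
            out[(r, t)] = strided[t * m // tp:(t + 1) * m // tp]
    return out

shards = strided_shard(list(range(8)), dp=4, tp=2)
print(shards[(0, 0)], shards[(0, 1)])  # [0] [4]
print(shards[(3, 0)], shards[(3, 1)])  # [3] [7]
```

This reproduces the grid above: DP rank 0 holds the strided shard [0, 4], which splits into [0] for TP=0 and [4] for TP=1, exactly as if TP had sharded first and FSDP second.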



greptile-apps bot commented Mar 4, 2026

Greptile Summary

This PR adds set_device_mesh(tp_mesh, weight_mesh) to all major TransformerEngineBaseModule subclasses (Linear, LayerNormLinear, LayerNormMLP, LayerNorm, RMSNorm, DotProductAttention, MultiheadAttention, TransformerLayer) and to the LayerNorm/RMSNorm FusibleOperations, enabling FSDP2×TP strided-shard (DTensor _StridedShard) compatible checkpointing via Torch DCP. It also includes a significant bugfix for the _LayerNormMLP backward pass and the FP8 amax_reduction_group fallback logic.

Key changes:

  • New API: set_device_mesh(tp_mesh, weight_mesh) on every TE module converts plain parameters to DTensors with the correct TP Shard/Replicate placement, and registers the weight_mesh's process group as amax_reduction_group for Float8CurrentScaling recipes.
  • _convert_param_to_dtensor_param / _extract_trainable_tensor_from_dtensor: Two new helpers in distributed.py. The latter uses a custom autograd function (_ToLocalIdentity) that preserves object identity with DTensor._local_tensor so FSDP2's in-place all-gather updates remain visible.
  • DTensor input localization in TransformerEngineBaseModule.pre_forward: inp.to_local() is applied before C++ kernels.
  • Bug fix: _LayerNormMLP backward was checking isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) (always False — quantizers are not storage objects) instead of isinstance(ctx.fc1_weight, QuantizedTensorStorage).
  • Bug fix: amax_reduction_group fallback no longer hardcodes the "shard" mesh dimension; it now uses the full DTensor's flat group and is only applied when not already set by set_device_mesh.
  • Block-scaling storage fix: _default_storage = torch.UntypedStorage(1) gives block-scaling quantized tensors a non-null storage pointer, fixing the FSDP2 crash for non-meta-device initialization.
  • Quantizer storage change: self.quantizers = {"scaling_fwd": [], "scaling_bwd": []} (dict → list) aligns with integer fp8_meta_index lookup semantics.
  • Tests: New standalone run_fsdp2_allgather.py test; run_fsdp2_model.py extended with a full DCP save/load round-trip including state-dict parity validation; test_torch_fsdp2.py parameterized to include (H/F)SDP-TP [NUM_PROCS//4, 2, 2] configuration.

Confidence Score: 3/5

  • This PR is a large, architecturally significant change — safe to proceed with minor issues addressed, but the args.sharding_dims None-guard in run_fsdp2_model.py (pre-existing flagged issue) and the CPU _default_storage device mismatch warrant attention before merge.
  • The core logic (DTensor wrapping, FSDP2-TP strided sharding, DCP round-trip) is well thought out and backed by Megatron-LM parity tests. The two previously-flagged blocking issues in the test script (args.sharding_dims null guard, CKPT_DIR f-string) remain unfixed. The _default_storage device mismatch (CPU storage on a CUDA tensor class) introduces a subtle but potentially latent incompatibility. The run_check=True in layer_norm.py/rmsnorm.py backward adds collective overhead on every backward pass for affected modules. These are all non-critical for correct DCP functionality but should be resolved before merge.
  • tests/pytorch/distributed/run_fsdp2_model.py (null-guard and f-string issues), transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py / mxfp8_tensor_storage.py / nvfp4_tensor_storage.py (_default_storage device), transformer_engine/pytorch/ops/basic/layer_norm.py and rmsnorm.py (run_check overhead).

Important Files Changed

Filename Overview
transformer_engine/pytorch/distributed.py Adds _convert_param_to_dtensor_param (wraps plain params as DTensors, copying __dict__ attributes) and _extract_trainable_tensor_from_dtensor (a custom autograd Function that preserves object identity between the returned tensor and DTensor._local_tensor for FSDP2 in-place update compatibility). Both helpers are used consistently throughout the module-level set_device_mesh implementations. Logic looks correct; _ToLocalIdentity.backward correctly uses run_check=False.
transformer_engine/pytorch/module/base.py Two key changes: (1) self.quantizers switched from {"scaling_fwd": {}, "scaling_bwd": {}} (dict-of-dicts) to {"scaling_fwd": [], "scaling_bwd": []} (dict-of-lists), fixing a bug where integer fp8_meta_index lookup was used on dict keys. (2) Adds DTensor localization of inp at the top of the pre-forward hook, ensuring TE C++ kernels receive plain tensors. Also replaces DTensor.from_local(...) + torch.nn.Parameter(...) with the more complete _convert_param_to_dtensor_param in reset_parameters. (3) Fixes the amax_reduction_group fallback to use device_mesh.get_group() (flat group) rather than the hardcoded "shard" mesh dimension, and guards this path behind quantizer.amax_reduction_group is None.
transformer_engine/pytorch/module/layernorm_mlp.py Adds set_device_mesh supporting FSDP2-TP strided sharding: fc1_weight→Shard(dim=0), fc2_weight→Shard(dim=1), biases and layer_norm→Replicate/Shard as appropriate. Also adds _get_bias_tensors and _get_layernorm_weight_and_bias helpers that extract local tensors from DTensors before passing to C++ kernels. Contains a meaningful bug fix: the _LayerNormMLP backward check isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) is corrected to isinstance(ctx.fc1_weight, QuantizedTensorStorage) — the former always evaluated to False since quantizers are not QuantizedTensorStorage instances.
transformer_engine/pytorch/module/linear.py Adds set_device_mesh converting weights to DTensors with Shard(dim=0) for column-parallel, Shard(dim=1) for row-parallel, and Replicate for no-TP mode. Adds _get_bias_tensors helper and refactors _get_weight_and_bias_tensors to use helpers. _set_tensor_parallel_attributes is extracted from reset_parameters to keep TP attribute logic separate from parameter initialization.
transformer_engine/pytorch/module/layernorm_linear.py Adds set_device_mesh with appropriate Shard/Replicate placements for weights, biases, and layer_norm parameters. Adds _get_bias_tensors and _get_layernorm_weight_and_bias helpers. _set_tensor_parallel_attributes refactored out of reset_parameters. All forward and ONNX paths updated to use the new getter helpers.
transformer_engine/pytorch/ops/basic/layer_norm.py DTensor support added to both forward and backward ops. Forward extracts local tensors via weight.to_local() (simple call, no identity preservation needed for forward). Backward wraps grad_weight and grad_bias back into DTensors using DTensor.from_local. However, run_check=False is not passed, triggering a collective metadata consistency check on every backward pass — unnecessary overhead since metadata is deterministic. Same issue exists in rmsnorm.py.
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py Adds _default_storage: torch.UntypedStorage (1 byte, CPU) to provide a non-null storage for block-scaling quantized tensors when no actual FP8 data exists. This fixes the FSDP2 crash from data_ptr() == 0, removing the previous xfail for non-meta-device initialization. However the storage is created on CPU rather than on the tensor's CUDA device, unlike the previous torch.UntypedStorage(0, device=self.device). Same applies to mxfp8_tensor_storage.py and nvfp4_tensor_storage.py.
tests/pytorch/distributed/run_fsdp2_model.py Major extension adding a full DCP checkpoint save/load round-trip test after training. Adds AppState(Stateful) that evicts _extra_state from the model checkpoint and clears empty optimizer states. The checkpoint path and parity validation logic is thorough. Contains the previously-flagged if len(args.sharding_dims) >= 3: guard outside the if args.sharding_dims: block (line 163) and the args.sharding_dims iteration in the CKPT_DIR f-string.
tests/pytorch/distributed/run_fsdp2_allgather.py New standalone allgather correctness test extracted and generalized from the old test_fp8_fsdp2_allgather in run_fsdp2_model.py. Handles FSDP, HSDP, and FSDP-TP configurations. Uses --sharding-dims required=True so the None-guard issue present in run_fsdp2_model.py doesn't apply here. Warm-up training step ensures FSDP2 lazy state is initialized before the allgather test. Well-structured.
transformer_engine/pytorch/transformer.py Adds tp_mesh and weight_mesh parameters to TransformerLayer.__init__, propagating them to self_attention, inter_attention, and layernorm_mlp. Adds a set_device_mesh method that validates the TP mesh, sets the TP group, and recursively calls set_device_mesh on all TE children.
transformer_engine/pytorch/tensor/nvfp4_tensor.py Adds untyped_storage() override (parallel to MXFP8Tensor), _default_storage for non-null storage support, and DTensor out-parameter unwrapping in the all-gather hook. Minor: the untyped_storage docstring incorrectly says "MXFP8Tensor" instead of "NVFP4Tensor".

Sequence Diagram

sequenceDiagram
    participant User
    participant TEModule as TE Module (Linear/LayerNormLinear/etc.)
    participant Base as TransformerEngineBaseModule
    participant Dist as distributed.py
    participant FSDP2 as fully_shard (FSDP2)
    participant DCP as Torch DCP

    User->>TEModule: __init__(tp_size, tp_mesh, weight_mesh)
    TEModule->>TEModule: set_device_mesh(tp_mesh, weight_mesh)
    TEModule->>Dist: _convert_param_to_dtensor_param(param, tp_mesh, Shard/Replicate)
    Dist-->>TEModule: DTensor(local_param, placement)
    TEModule->>TEModule: reset_parameters(defer_init=True/False)

    User->>FSDP2: fully_shard(model, mesh[dp_dims])
    Note over FSDP2: Detects DTensor Shard(dim=0) on DP-matching dim<br/>→ uses _StridedShard placement for FSDP-TP

    User->>TEModule: reset_parameters() [meta device only]
    Base->>Base: quantize param (FP8)
    Base->>Dist: _convert_param_to_dtensor_param(fp8_param, dtensor.device_mesh, ...)
    Dist-->>Base: DTensor(fp8_param, same placement)

    rect rgb(200, 230, 255)
        Note over TEModule,FSDP2: Training Forward Pass
        FSDP2->>FSDP2: all-gather sharded DTensor weight
        Note over FSDP2: TP-sharded DTensor remains after all-gather
        TEModule->>Dist: _extract_trainable_tensor_from_dtensor(dtensor)
        Dist-->>TEModule: local Tensor (identity-preserved via _ToLocalIdentity)
        TEModule->>TEModule: C++ kernel (plain Tensor)
    end

    rect rgb(255, 230, 200)
        Note over TEModule,FSDP2: DCP Checkpoint Save/Load
        User->>DCP: save({"app": AppState(model, optimizer)})
        Note over DCP: AppState.state_dict() evicts _extra_state,<br/>clears empty optimizer states for empty params
        DCP-->>User: checkpoint written

        User->>DCP: load({"app": AppState(model, optimizer)})
        Note over DCP: set_state_dict(strict=False) ignores<br/>_extra_state and empty optimizer entries
        DCP-->>User: checkpoint restored
    end

Comments Outside Diff (3)

  1. transformer_engine/pytorch/ops/basic/layer_norm.py, line 265-272 (link)

    DTensor.from_local with implicit run_check=True on every backward pass

    DTensor.from_local defaults to run_check=True, which performs a collective metadata consistency check (verifying all ranks agree on global shape/strides/dtype) on every single backward pass. Since the metadata is deterministically derived from self.weight, this check is unnecessary and adds communication overhead in the backward pass.

    Use run_check=False here (and similarly for grad_bias on line 274, and in the equivalent code in rmsnorm.py). The _ToLocalIdentity.backward helper in distributed.py already sets run_check=False as a pattern to follow.

  2. transformer_engine/pytorch/ops/basic/rmsnorm.py, line 248-256 (link)

    DTensor.from_local with implicit run_check=True on every backward pass

    Same as the layer_norm.py backward: run_check=True (the default) triggers a collective sanity-check on every backward pass. Since metadata is deterministic, this is unnecessary overhead.

  3. transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py, line 692 (link)

    _default_storage is created on CPU, but the tensor lives on CUDA

    torch.UntypedStorage(1) without a device argument creates a CPU storage, while Float8BlockwiseQTensorStorage and its sibling classes (MXFP8TensorStorage, NVFP4TensorStorage) hold CUDA tensors. The previous implementation explicitly matched the device: torch.UntypedStorage(0, device=self.device).

    While the tests now pass (FSDP2 apparently only checks data_ptr() != 0), this could cause subtle issues in future PyTorch versions or with code that inspects storage.device — e.g. PyTorch internal validation checks that compare a tensor's storage device to the tensor's device.

    Consider creating the default storage on the same device as the tensor:
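A minimal sketch of that suggestion. `make_default_storage` is an illustrative helper; real code would pass the storage class's `self.device`:

```python
import torch

def make_default_storage(device) -> torch.UntypedStorage:
    # 1-byte placeholder allocated on the tensor's own device, matching the
    # device handling of the previous torch.UntypedStorage(0, device=...) code.
    return torch.UntypedStorage(1, device=device)

storage = make_default_storage("cpu")  # would be a CUDA device in practice
print(storage.nbytes(), storage.data_ptr() != 0)  # 1 True
```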

    The same applies to mxfp8_tensor_storage.py line 713 and nvfp4_tensor_storage.py line 134.

Last reviewed commit: 7ea9ab6

@cspades force-pushed the cye/fsdp2-tp-dcp branch from 4ec2947 to dbb9d14 on March 4, 2026 18:10
@cspades force-pushed the cye/fsdp2-tp-dcp branch from fcdd5bd to c912f5b on March 5, 2026 16:06
@cspades force-pushed the cye/fsdp2-tp-dcp branch 5 times, most recently from bc82f02 to 267f1df on March 10, 2026 01:30
@vthumbe1503 (Collaborator) commented:

/te-ci L1 pytorch

@cspades force-pushed the cye/fsdp2-tp-dcp branch 4 times, most recently from f0b3cae to af7362a on March 12, 2026 15:26
@cspades force-pushed the cye/fsdp2-tp-dcp branch 3 times, most recently from 5d473b8 to 9435382 on March 12, 2026 21:10
cspades and others added 8 commits March 16, 2026 08:52

cspades commented Mar 16, 2026

/te-ci L1 pytorch
